{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TensorFlow：静态图"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PyTorch自动求导看起来非常像TensorFlow：这两个框架中，我们都定义计算图，使用自动微分来计算梯度。两者最大的不同就是TensorFlow的计算图是**静态的**，而PyTorch使用**动态的**计算图。 \n",
    "\n",
    "在TensorFlow中，我们定义计算图一次，然后重复执行这个相同的图，可能会提供不同的输入数据。而在PyTorch中，每一个前向通道定义一个新的计算图。 \n",
    "\n",
    "静态图的好处在于你可以预先对图进行优化。例如，一个框架可能要融合一些图的运算来提升效率，或者产生一个策略来将图分布到多个GPU或机器上。如果重复使用相同的图，那么在重复运行同一个图时，，前期潜在的代价高昂的预先优化的消耗就会被分摊开。\n",
    "\n",
    "静态图和动态图的一个区别是控制流。对于一些模型，我们希望对每个数据点执行不同的计算。例如，一个递归神经网络可能对于每个数据点执行不同的时间步数，这个展开（unrolling）可以作为一个循环来实现。对于一个静态图，循环结构要作为图的一部分。因此，TensorFlow提供了运算符（例如`tf.scan`）来把循环嵌入到图当中。对于动态图来说，情况更加简单：既然我们为每个例子即时创建图，我们可以使用普通的命令式控制流来为每个输入执行不同的计算。 \n",
    "\n",
    "为了与上面的PyTorch自动梯度实例做对比，我们使用TensorFlow来拟合一个简单的2层网络："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "40589096.0\n",
      "37821904.0\n",
      "36510708.0\n",
      "30501270.0\n",
      "20827848.0\n",
      "11729495.0\n",
      "6143866.0\n",
      "3364360.8\n",
      "2087345.2\n",
      "1463462.9\n",
      "1118738.6\n",
      "899005.56\n",
      "742703.4\n",
      "623438.25\n",
      "528738.25\n",
      "451812.38\n",
      "388369.94\n",
      "335582.6\n",
      "291366.1\n",
      "254078.58\n",
      "222403.28\n",
      "195346.31\n",
      "172135.34\n",
      "152153.95\n",
      "134849.33\n",
      "119816.305\n",
      "106711.28\n",
      "95249.88\n",
      "85200.125\n",
      "76371.72\n",
      "68593.875\n",
      "61719.89\n",
      "55627.27\n",
      "50220.57\n",
      "45409.52\n",
      "41122.0\n",
      "37293.066\n",
      "33866.797\n",
      "30797.146\n",
      "28041.08\n",
      "25562.172\n",
      "23329.615\n",
      "21317.312\n",
      "19499.893\n",
      "17857.492\n",
      "16370.625\n",
      "15022.511\n",
      "13799.208\n",
      "12687.409\n",
      "11676.729\n",
      "10756.279\n",
      "9917.69\n",
      "9151.785\n",
      "8452.016\n",
      "7811.713\n",
      "7225.5547\n",
      "6688.344\n",
      "6195.357\n",
      "5742.7827\n",
      "5327.023\n",
      "4944.6875\n",
      "4592.6323\n",
      "4268.4062\n",
      "3969.478\n",
      "3693.717\n",
      "3439.3118\n",
      "3204.3713\n",
      "2986.9905\n",
      "2785.916\n",
      "2599.763\n",
      "2427.266\n",
      "2266.9893\n",
      "2118.3594\n",
      "1980.4802\n",
      "1852.3848\n",
      "1733.4098\n",
      "1622.8286\n",
      "1520.0109\n",
      "1424.2677\n",
      "1335.126\n",
      "1252.0669\n",
      "1174.6736\n",
      "1102.5017\n",
      "1035.1385\n",
      "972.2462\n",
      "913.5196\n",
      "858.66516\n",
      "807.37445\n",
      "759.43335\n",
      "714.56573\n",
      "672.57227\n",
      "633.2573\n",
      "596.42523\n",
      "561.9042\n",
      "529.5462\n",
      "499.20865\n",
      "470.73645\n",
      "444.01196\n",
      "418.92288\n",
      "395.37564\n",
      "373.24405\n",
      "352.44128\n",
      "332.88275\n",
      "314.4971\n",
      "297.20413\n",
      "280.92273\n",
      "265.59735\n",
      "251.1689\n",
      "237.57965\n",
      "224.78468\n",
      "212.71817\n",
      "201.34525\n",
      "190.62335\n",
      "180.51225\n",
      "170.9723\n",
      "161.96956\n",
      "153.47028\n",
      "145.44525\n",
      "137.87146\n",
      "130.71455\n",
      "123.95056\n",
      "117.56006\n",
      "111.51813\n",
      "105.80647\n",
      "100.403564\n",
      "95.29393\n",
      "90.45943\n",
      "85.88639\n",
      "81.55574\n",
      "77.45676\n",
      "73.57431\n",
      "69.89832\n",
      "66.41628\n",
      "63.116283\n",
      "59.98909\n",
      "57.025627\n",
      "54.215267\n",
      "51.55214\n",
      "49.025974\n",
      "46.630264\n",
      "44.357086\n",
      "42.200153\n",
      "40.15351\n",
      "38.210358\n",
      "36.36695\n",
      "34.615128\n",
      "32.953022\n",
      "31.373756\n",
      "29.874235\n",
      "28.449154\n",
      "27.094942\n",
      "25.809193\n",
      "24.586369\n",
      "23.423946\n",
      "22.318218\n",
      "21.26784\n",
      "20.268692\n",
      "19.318031\n",
      "18.41416\n",
      "17.554228\n",
      "16.73655\n",
      "15.957744\n",
      "15.217113\n",
      "14.512032\n",
      "13.840776\n",
      "13.201844\n",
      "12.593298\n",
      "12.013601\n",
      "11.462084\n",
      "10.936723\n",
      "10.436237\n",
      "9.959286\n",
      "9.504883\n",
      "9.071983\n",
      "8.659666\n",
      "8.266763\n",
      "7.892255\n",
      "7.5349283\n",
      "7.1944885\n",
      "6.869582\n",
      "6.5602045\n",
      "6.2650437\n",
      "5.9837112\n",
      "5.715268\n",
      "5.459203\n",
      "5.2152414\n",
      "4.982437\n",
      "4.760192\n",
      "4.54817\n",
      "4.3455434\n",
      "4.1526113\n",
      "3.9683657\n",
      "3.7924674\n",
      "3.6245632\n",
      "3.4643197\n",
      "3.3112822\n",
      "3.16539\n",
      "3.0258243\n",
      "2.8929996\n",
      "2.7656598\n",
      "2.6443655\n",
      "2.52833\n",
      "2.417658\n",
      "2.3118794\n",
      "2.2107594\n",
      "2.1143537\n",
      "2.022164\n",
      "1.9340122\n",
      "1.8497837\n",
      "1.7694559\n",
      "1.6926887\n",
      "1.6192617\n",
      "1.549055\n",
      "1.4818767\n",
      "1.4178338\n",
      "1.3565595\n",
      "1.2979827\n",
      "1.2420561\n",
      "1.188442\n",
      "1.1373023\n",
      "1.0883944\n",
      "1.0415248\n",
      "0.9968837\n",
      "0.95416886\n",
      "0.9131793\n",
      "0.8740844\n",
      "0.8366753\n",
      "0.80090094\n",
      "0.7666666\n",
      "0.73398066\n",
      "0.7026192\n",
      "0.67268026\n",
      "0.64399487\n",
      "0.6165598\n",
      "0.5903982\n",
      "0.56529105\n",
      "0.5412486\n",
      "0.5183328\n",
      "0.4963334\n",
      "0.4754038\n",
      "0.45525733\n",
      "0.4360122\n",
      "0.4175854\n",
      "0.3999764\n",
      "0.38305098\n",
      "0.3668703\n",
      "0.35141098\n",
      "0.3365595\n",
      "0.32239074\n",
      "0.30886626\n",
      "0.2958819\n",
      "0.2834331\n",
      "0.27151662\n",
      "0.26012555\n",
      "0.24923459\n",
      "0.23878607\n",
      "0.22880428\n",
      "0.21921614\n",
      "0.21005003\n",
      "0.20125386\n",
      "0.19285272\n",
      "0.1848178\n",
      "0.17709963\n",
      "0.16970985\n",
      "0.16261409\n",
      "0.15585387\n",
      "0.1493829\n",
      "0.14317018\n",
      "0.13723016\n",
      "0.1314907\n",
      "0.12605126\n",
      "0.1208179\n",
      "0.11583139\n",
      "0.11101707\n",
      "0.10643093\n",
      "0.102020286\n",
      "0.09778917\n",
      "0.093738705\n",
      "0.08988893\n",
      "0.08614388\n",
      "0.08259517\n",
      "0.079182886\n",
      "0.07591591\n",
      "0.0727933\n",
      "0.06980319\n",
      "0.066904925\n",
      "0.06415297\n",
      "0.061523013\n",
      "0.058991455\n",
      "0.05656487\n",
      "0.054241817\n",
      "0.052009124\n",
      "0.0498687\n",
      "0.047838975\n",
      "0.045893796\n",
      "0.044006646\n",
      "0.042217474\n",
      "0.040483616\n",
      "0.038837142\n",
      "0.037256647\n",
      "0.03573991\n",
      "0.034271874\n",
      "0.032883074\n",
      "0.031539746\n",
      "0.030251838\n",
      "0.02902964\n",
      "0.027862137\n",
      "0.026728727\n",
      "0.025651127\n",
      "0.02460733\n",
      "0.023608057\n",
      "0.022666462\n",
      "0.021745328\n",
      "0.02086803\n",
      "0.020025901\n",
      "0.019226829\n",
      "0.018454809\n",
      "0.01771463\n",
      "0.017010594\n",
      "0.016334392\n",
      "0.015679605\n",
      "0.015054845\n",
      "0.014457512\n",
      "0.013883985\n",
      "0.013331914\n",
      "0.012802713\n",
      "0.01229547\n",
      "0.011808048\n",
      "0.01134156\n",
      "0.010894298\n",
      "0.010459685\n",
      "0.010051243\n",
      "0.00965915\n",
      "0.009278058\n",
      "0.008917944\n",
      "0.008568482\n",
      "0.00823351\n",
      "0.00791322\n",
      "0.007609226\n",
      "0.007313578\n",
      "0.007036432\n",
      "0.0067658797\n",
      "0.00650423\n",
      "0.006257597\n",
      "0.006018062\n",
      "0.0057887123\n",
      "0.0055777817\n",
      "0.0053665484\n",
      "0.005161192\n",
      "0.0049654115\n",
      "0.004782594\n",
      "0.0046034837\n",
      "0.0044314093\n",
      "0.004266924\n",
      "0.004108629\n",
      "0.003955093\n",
      "0.0038102847\n",
      "0.0036684272\n",
      "0.0035340425\n",
      "0.003405972\n",
      "0.0032828378\n",
      "0.003164256\n",
      "0.0030513133\n",
      "0.0029407754\n",
      "0.0028384721\n",
      "0.0027361098\n",
      "0.0026382185\n",
      "0.00254692\n",
      "0.002457071\n",
      "0.0023719957\n",
      "0.0022889297\n",
      "0.0022089095\n",
      "0.0021355487\n",
      "0.0020610488\n",
      "0.0019923944\n",
      "0.001924263\n",
      "0.0018614884\n",
      "0.0017974387\n",
      "0.0017387401\n",
      "0.0016805987\n",
      "0.0016271127\n",
      "0.0015741684\n",
      "0.0015242965\n",
      "0.0014746229\n",
      "0.0014274437\n",
      "0.0013824594\n",
      "0.0013388536\n",
      "0.001296876\n",
      "0.00125279\n",
      "0.0012153474\n",
      "0.0011787849\n",
      "0.0011418835\n",
      "0.0011073913\n",
      "0.0010736702\n",
      "0.0010400359\n",
      "0.0010100233\n",
      "0.0009806964\n",
      "0.00095189386\n",
      "0.0009247576\n",
      "0.0008982285\n",
      "0.00087163487\n",
      "0.0008472034\n",
      "0.0008231002\n",
      "0.00079884985\n",
      "0.00077604235\n",
      "0.00075402425\n",
      "0.0007340207\n",
      "0.0007141097\n",
      "0.00069453195\n",
      "0.0006752604\n",
      "0.0006563074\n",
      "0.0006406471\n",
      "0.00062222814\n",
      "0.00060647784\n",
      "0.0005901075\n",
      "0.00057464273\n",
      "0.0005608851\n",
      "0.00054660224\n",
      "0.0005313165\n",
      "0.0005183507\n",
      "0.0005054183\n",
      "0.00049229787\n",
      "0.00047964946\n",
      "0.00046927363\n",
      "0.00045756827\n",
      "0.00044661807\n",
      "0.00043569435\n",
      "0.00042536948\n",
      "0.0004148768\n",
      "0.00040489645\n",
      "0.0003955367\n",
      "0.00038605987\n",
      "0.00037693084\n",
      "0.00036768033\n",
      "0.00035982145\n",
      "0.00035199407\n",
      "0.00034366417\n",
      "0.00033610844\n",
      "0.00032827296\n",
      "0.00032113644\n",
      "0.00031387087\n",
      "0.0003071338\n",
      "0.00030077208\n",
      "0.00029348748\n",
      "0.0002872242\n",
      "0.00028121652\n",
      "0.00027480905\n",
      "0.0002700148\n",
      "0.00026436243\n",
      "0.00025887528\n",
      "0.00025346753\n",
      "0.00024772316\n",
      "0.00024315408\n",
      "0.00023835928\n",
      "0.0002334912\n",
      "0.000228311\n",
      "0.00022444785\n",
      "0.00021977603\n",
      "0.00021514796\n",
      "0.00021072998\n",
      "0.00020690367\n",
      "0.00020286307\n",
      "0.00019924743\n",
      "0.00019552487\n",
      "0.00019186549\n",
      "0.00018774494\n",
      "0.00018440498\n",
      "0.00018107734\n",
      "0.00017771842\n",
      "0.00017441464\n",
      "0.00017167685\n",
      "0.00016890634\n",
      "0.00016565295\n",
      "0.00016285216\n",
      "0.0001598886\n",
      "0.00015756505\n",
      "0.0001544411\n",
      "0.00015191144\n",
      "0.00014929647\n",
      "0.00014699742\n",
      "0.00014417457\n",
      "0.00014156639\n",
      "0.00013918911\n",
      "0.00013722279\n",
      "0.00013506513\n",
      "0.00013253758\n",
      "0.00013042818\n",
      "0.00012805709\n",
      "0.00012586355\n",
      "0.00012390668\n",
      "0.000122017074\n",
      "0.00012012456\n",
      "0.00011800403\n",
      "0.000116093375\n",
      "0.00011421463\n",
      "0.00011267222\n",
      "0.00011051622\n",
      "0.00010935906\n",
      "0.000107399384\n",
      "0.0001060903\n"
     ]
    }
   ],
   "source": [
    "# 可运行代码见本文件夹中的 tf_two_layer_net.py\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "\n",
    "# 首先我们建立计算图（computational graph）\n",
    "\n",
    "# N是批大小；D是输入维度；\n",
    "# H是隐藏层维度；D_out是输出维度。\n",
    "N, D_in, H, D_out = 64, 1000, 100, 10\n",
    "\n",
    "# 为输入和目标数据创建placeholder；\n",
    "# 当执行计算图时，他们将会被真实的数据填充\n",
    "x = tf.placeholder(tf.float32, shape=(None, D_in))\n",
    "y = tf.placeholder(tf.float32, shape=(None, D_out))\n",
    "\n",
    "# 为权重创建Variable并用随机数据初始化\n",
    "# TensorFlow的Variable在执行计算图时不会改变\n",
    "w1 = tf.Variable(tf.random_normal((D_in, H)))\n",
    "w2 = tf.Variable(tf.random_normal((H, D_out)))\n",
    "\n",
    "# 前向传播：使用TensorFlow的张量运算计算预测值y。\n",
    "# 注意这段代码实际上不执行任何数值运算；\n",
    "# 它只是建立了我们稍后将执行的计算图。\n",
    "h = tf.matmul(x, w1)\n",
    "h_relu = tf.maximum(h, tf.zeros(1))\n",
    "y_pred = tf.matmul(h_relu, w2)\n",
    "\n",
    "# 使用TensorFlow的张量运算损失（loss）\n",
    "loss = tf.reduce_sum((y - y_pred) ** 2.0)\n",
    "\n",
    "# 计算loss对于w1和w2的导数\n",
    "grad_w1, grad_w2 = tf.gradients(loss, [w1, w2])\n",
    "\n",
    "# 使用梯度下降更新权重。为了实际更新权重，我们需要在执行计算图时计算new_w1和new_w2。\n",
    "# 注意，在TensorFlow中，更新权重值的行为是计算图的一部分;\n",
    "# 但在PyTorch中，这发生在计算图形之外。\n",
    "learning_rate = 1e-6\n",
    "new_w1 = w1.assign(w1 - learning_rate * grad_w1)\n",
    "new_w2 = w2.assign(w2 - learning_rate * grad_w2)\n",
    "\n",
    "# 现在我们搭建好了计算图，所以我们开始一个TensorFlow的会话（session）来实际执行计算图。\n",
    "with tf.Session() as sess:\n",
    "\n",
    "    # 运行一次计算图来初始化Variable w1和w2\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "\n",
    "    # 创建numpy数组来存储输入x和目标y的实际数据\n",
    "    x_value = np.random.randn(N, D_in)\n",
    "    y_value = np.random.randn(N, D_out)\n",
    "    \n",
    "    for _ in range(500):\n",
    "        # 多次运行计算图。每次执行时，我们都用feed_dict参数，\n",
    "        # 将x_value绑定到x，将y_value绑定到y，\n",
    "        # 每次执行图形时我们都要计算损失、new_w1和new_w2；\n",
    "        # 这些张量的值以numpy数组的形式返回。\n",
    "        loss_value, _, _ = sess.run([loss, new_w1, new_w2], \n",
    "                                    feed_dict={x: x_value, y: y_value})\n",
    "        print(loss_value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (Spyder)",
   "language": "python3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": false,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "227.797px"
   },
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
